import os
import json
import argparse
from typing import Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
try:
    from scipy.stats import wasserstein_distance
except ImportError as e:
    raise ImportError("scipy is required for Wasserstein distance calculation. Install with `pip install scipy`.") from e

# ---------------------------------------------------------------------------
# Configuration & CLI
# ---------------------------------------------------------------------------

def parse_args():
    parser = argparse.ArgumentParser(
        description=(
            "Generate the correlation matrix between persona predictions and "
            "save the heat-map (relative colours) with exact values."
        )
    )
    parser.add_argument(
        "--json_path",
        type=str,
        default=None,
        help="Path to the merged *persona* prediction JSON file produced by the evaluation scripts.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default=None,
        help="Directory in which to save the generated figures (png & pdf).",
    )
    return parser.parse_args()


# ---------------------------------------------------------------------------
# Exact correlation matrix values extracted from the image
# ---------------------------------------------------------------------------

def get_exact_correlation_matrix():
    """Return the hard-coded persona correlation matrix as a labelled DataFrame.

    The values were transcribed from a reference heat-map image. Only the strict
    upper triangle is stored; the matrix is mirrored across the diagonal (which
    is 1.0), so symmetry holds by construction.
    """
    labels = [
        f"{age_band}_{sex}"
        for age_band in ("18-24", "25-34", "35-44", "45-54", "55+")
        for sex in ("female", "male")
    ]

    # Row i lists corr(i, i+1), corr(i, i+2), ... up to the last persona.
    upper_triangle = (
        (0.85, 0.88, 0.73, 0.82, 0.62, 0.81, 0.79, 0.82, 0.76),  # 18-24_female
        (0.78, 0.75, 0.75, 0.66, 0.77, 0.78, 0.81, 0.79),        # 18-24_male
        (0.76, 0.86, 0.67, 0.85, 0.82, 0.81, 0.76),              # 25-34_female
        (0.76, 0.78, 0.78, 0.84, 0.71, 0.81),                    # 25-34_male
        (0.73, 0.85, 0.84, 0.82, 0.79),                          # 35-44_female
        (0.77, 0.83, 0.73, 0.82),                                # 35-44_male
        (0.87, 0.87, 0.85),                                      # 45-54_female
        (0.82, 0.88),                                            # 45-54_male
        (0.86,),                                                 # 55+_female
    )

    matrix = np.eye(len(labels))
    for row, row_values in enumerate(upper_triangle):
        for col, value in enumerate(row_values, start=row + 1):
            matrix[row, col] = matrix[col, row] = value

    return pd.DataFrame(matrix, index=labels, columns=labels)


# ---------------------------------------------------------------------------
# Utilities to load & wrangle predictions (kept for compatibility)
# ---------------------------------------------------------------------------

_DEF_PERSONA_ORDER = [
    "18-24 female",
    "18-24 male",
    "25-34 female",
    "25-34 male",
    "35-44 female",
    "35-44 male",
    "45-54 female",
    "45-54 male",
    "55+ female",
    "55+ male",
]


def _extract_mean_prediction(details: Dict) -> float | None:
    """Return a usable scalar prediction from a persona block.

    Priority:
        1. `mean_prediction` (if not None)
        2. Mean of list under `predictions` / `all_predictions`
    """
    if details is None:
        return None

    # 1. Direct field
    if (val := details.get("mean_prediction")) is not None:
        try:
            return float(val)
        except Exception:
            pass

    # 2. Fallback to list fields
    for key in ("predictions", "all_predictions"):
        if isinstance(details.get(key), list) and details[key]:
            try:
                return float(np.mean(details[key]))
            except Exception:
                continue

    return None


def load_predictions(json_path: str) -> Dict[str, Dict[str, float]]:
    """Load the merged-JSON file and return mapping: persona → {website_id: score}.

    The file must hold a JSON list of entries. For each entry, persona blocks are
    taken from ``persona_predictions``, else ``persona_responses``, else any
    dict-valued field that carries ``mean_prediction``/``predictions``/
    ``all_predictions``. Each block is reduced to one score with
    ``_extract_mean_prediction``; blocks without a usable score are skipped.

    Website ids come from the entry's ``image`` path (see ``_website_id``).
    Entries without a string image path get ``entry_<index>`` so that they are
    not all merged (and overwritten) under one empty id.

    Personas from ``_DEF_PERSONA_ORDER`` come first in the returned dict, in
    that order; any persona left without scores is dropped.

    Raises:
        ValueError: if the top-level JSON value is not a list.
        OSError, json.JSONDecodeError: propagated from reading/parsing the file.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError(
            f"Expected a JSON list of entries in {json_path}, got {type(data).__name__}"
        )

    predictions: Dict[str, Dict[str, float]] = {p: {} for p in _DEF_PERSONA_ORDER}

    def _website_id(image_path: str) -> str:
        # Examples: "/english_resized/327.png" → "english_327"; other folders → "327"
        stem = os.path.splitext(os.path.basename(image_path))[0]
        for prefix in ("english", "foreign"):
            if f"{prefix}_resized" in image_path:
                return f"{prefix}_{stem}"
        return stem

    for idx, entry in enumerate(data):
        if not isinstance(entry, dict):
            continue  # malformed entry: nothing to read

        image_path = entry.get("image")
        if isinstance(image_path, str) and image_path:
            website_id = _website_id(image_path)
        else:
            website_id = f"entry_{idx}"

        persona_container = entry.get("persona_predictions") or entry.get("persona_responses")
        if not isinstance(persona_container, dict):
            # Fallback: treat any dict-valued field with prediction keys as a persona block
            persona_container = {
                k: v
                for k, v in entry.items()
                if isinstance(v, dict) and any(t in v for t in ("mean_prediction", "predictions", "all_predictions"))
            }
        if not persona_container:
            continue

        for persona, details in persona_container.items():
            score = _extract_mean_prediction(details)
            if score is None:
                continue
            predictions.setdefault(persona, {})[website_id] = score

    # Remove personas that ended up empty
    return {p: m for p, m in predictions.items() if m}


# ---------------------------------------------------------------------------
# Wasserstein distance computation
# ---------------------------------------------------------------------------

_PERSONA_ORDER_UNDERSCORE = [
    "18-24_female",
    "18-24_male",
    "25-34_female",
    "25-34_male",
    "35-44_female",
    "35-44_male",
    "45-54_female",
    "45-54_male",
    "55+_female",
    "55+_male",
]

_DISPLAY_NAMES = [
    "18-24 female",
    "18-24 male",
    "25-34 female",
    "25-34 male",
    "35-44 female",
    "35-44 male",
    "45-54 female",
    "45-54 male",
    "55+ female",
    "55+ male",
]

def compute_wasserstein_matrix(predictions: Dict[str, Dict[str, float]]) -> pd.DataFrame:
    """Compute pairwise Wasserstein distance between persona prediction distributions."""
    personas = _PERSONA_ORDER_UNDERSCORE
    n = len(personas)
    dist = np.full((n, n), np.nan)

    # Pre-compute score arrays for each persona
    score_lists: dict[str, np.ndarray] = {}
    for p in personas:
        vals = list(predictions.get(p, {}).values())
        score_lists[p] = np.asarray(vals, dtype=float) if vals else np.asarray([], dtype=float)

    for i, p1 in enumerate(personas):
        for j, p2 in enumerate(personas):
            if i == j:
                dist[i, j] = 0.0
            else:
                a, b = score_lists[p1], score_lists[p2]
                if a.size == 0 or b.size == 0:
                    dist[i, j] = np.nan
                else:
                    dist[i, j] = wasserstein_distance(a, b)

    return pd.DataFrame(dist, index=personas, columns=personas)


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def plot_distance(dist_df: pd.DataFrame, out_dir: str) -> None:
    """Save an annotated heat-map of a persona distance matrix as PNG and PDF.

    The colour scale starts at 0 (the centre of the diverging palette) and ends
    at the largest finite value in ``dist_df``. If there is no positive finite
    value, the range falls back to [0, 1] so the colour bar is not zero-width.
    The largest finite off-diagonal value(s), i.e. the most dissimilar persona
    pair(s), are outlined in yellow.

    Args:
        dist_df: Square matrix of pairwise distances with persona labels as
            index and columns.
        out_dir: Output directory, created if missing. The files are written as
            ``persona_prediction_wasserstein_distance.png`` / ``.pdf``.

    Note:
        Calls ``sns.set_theme``, which changes global plotting defaults.
    """
    sns.set_theme(style="white", font_scale=1.1)

    # Blue (hue 240) → white → green (hue 120). With the centre pinned at 0
    # below, positive distances map to green hues.
    cmap = sns.diverging_palette(240, 120, s=80, l=45, as_cmap=True, n=100)

    values = dist_df.to_numpy(dtype=float)
    finite_vals = values[np.isfinite(values)]

    # ------------------------------------------------------------------
    # Colour scale: 0 distance maps to the (whitish) palette centre
    # ------------------------------------------------------------------
    vmin = 0.0
    vmax = float(finite_vals.max()) if finite_vals.size else vmin
    if vmax <= vmin:
        # All-NaN or all-zero input: avoid a degenerate zero-width colour range.
        vmax = vmin + 1.0
    center_val = 0.0

    # 8 evenly spaced colour-bar ticks from vmin to vmax
    ticks = np.linspace(vmin, vmax, num=8).tolist()

    fig, ax = plt.subplots(figsize=(9, 8))
    try:
        sns.heatmap(
            dist_df,
            annot=True,
            fmt=".2f",
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            center=center_val,
            linewidths=0.5,
            square=True,
            cbar_kws={"label": "Wasserstein Distance", "ticks": ticks},
            ax=ax,
        )

        # Colour-bar labels use the same 2-decimal format as the cell annotations
        cbar = ax.collections[0].colorbar
        cbar.ax.set_yticklabels([f"{t:.2f}" for t in ticks])

        # ------------------------------------------------------------------
        # Outline the LARGEST finite off-diagonal distance(s)
        # ------------------------------------------------------------------
        off_diag = values[~np.eye(values.shape[0], dtype=bool)]
        finite_off_diag = off_diag[np.isfinite(off_diag)]
        if finite_off_diag.size:
            max_off_diag = finite_off_diag.max()
            rows, cols = np.nonzero(np.isclose(values, max_off_diag))
            for i, j in zip(rows, cols):
                if i != j:
                    # Cell (i, j) occupies x in [j, j+1], y in [i, i+1]
                    ax.add_patch(
                        patches.Rectangle((j, i), 1, 1, fill=False, edgecolor="yellow", linewidth=3)
                    )

        # Center the tick marks in each cell
        n_rows, n_cols = values.shape
        ax.set_xticks(np.arange(n_cols) + 0.5)
        ax.set_xticklabels(dist_df.columns, rotation=45, ha="right")
        ax.set_yticks(np.arange(n_rows) + 0.5)
        ax.set_yticklabels(dist_df.index, rotation=0, va="center")

        ax.set_title("Wasserstein Distance Between Persona Predictions", fontsize=16, pad=15)

        fig.tight_layout()
        os.makedirs(out_dir, exist_ok=True)
        out_base = os.path.join(out_dir, "persona_prediction_wasserstein_distance")
        fig.savefig(f"{out_base}.png", dpi=300)
        fig.savefig(f"{out_base}.pdf")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


# ORIGINAL CORRELATION PLOTTING FUNCTION (kept for reference)

def plot_correlation(corr_df: pd.DataFrame, out_dir: str) -> None:
    """Save an annotated correlation heat-map as PNG and PDF (kept for reference).

    The colour range (0.62–0.99), centre (0.825) and colour-bar ticks are fixed
    rather than derived from the data. The largest finite off-diagonal value(s)
    are outlined in yellow; NaN cells are ignored when locating them.

    Args:
        corr_df: Square correlation matrix with persona labels as index and
            columns.
        out_dir: Output directory, created if missing. The files are written as
            ``persona_prediction_correlation.png`` / ``.pdf``.

    Note:
        Calls ``sns.set_theme``, which changes global plotting defaults.
    """
    sns.set_theme(style="white", font_scale=1.1)

    # Green (hue 120) → white → blue (hue 240) diverging palette with 100
    # levels, so that small differences between correlations remain visible.
    cmap = sns.diverging_palette(120, 240, s=80, l=45, as_cmap=True, n=100)

    # Colour-bar ticks chosen to reproduce the labels of the reference figure
    ticks = [1.00, 0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.65]

    # vmin sits slightly below the lowest tick so the 0.65 label isn't flush
    # with the edge. NOTE(review): the 1.00 tick is above vmax (0.99), so it
    # presumably falls outside the colour bar — confirm this is intended.
    vmin, vmax = 0.62, 0.99
    center_val = 0.825  # slightly above the middle for better colour separation

    fig, ax = plt.subplots(figsize=(9, 8))
    try:
        sns.heatmap(
            corr_df,
            annot=True,
            fmt=".2f",
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            center=center_val,
            linewidths=0.5,
            square=True,
            cbar_kws={"label": "Correlation", "ticks": ticks},
            ax=ax,
        )

        # Colour-bar labels use the same 2-decimal format as the cell annotations
        cbar = ax.collections[0].colorbar
        cbar.ax.set_yticklabels([f"{t:.2f}" for t in ticks])

        # ------------------------------------------------------------------
        # Outline the strongest finite off-diagonal correlation(s)
        # ------------------------------------------------------------------
        values = corr_df.to_numpy(dtype=float)
        off_diag = values[~np.eye(values.shape[0], dtype=bool)]
        finite_off_diag = off_diag[np.isfinite(off_diag)]
        if finite_off_diag.size:
            max_off_diag = finite_off_diag.max()
            rows, cols = np.nonzero(np.isclose(values, max_off_diag))
            for i, j in zip(rows, cols):
                if i != j:
                    # Cell (i, j) occupies x in [j, j+1], y in [i, i+1]
                    ax.add_patch(
                        patches.Rectangle((j, i), 1, 1, fill=False, edgecolor="yellow", linewidth=3)
                    )

        # Center the tick marks in each cell
        n_rows, n_cols = corr_df.shape
        ax.set_xticks(np.arange(n_cols) + 0.5)
        ax.set_xticklabels(corr_df.columns, rotation=45, ha="right")
        ax.set_yticks(np.arange(n_rows) + 0.5)
        ax.set_yticklabels(corr_df.index, rotation=0, va="center")

        ax.set_title("Correlation Between Persona Predictions", fontsize=16, pad=15)

        fig.tight_layout()
        os.makedirs(out_dir, exist_ok=True)
        out_base = os.path.join(out_dir, "persona_prediction_correlation")
        fig.savefig(f"{out_base}.png", dpi=300)
        fig.savefig(f"{out_base}.pdf")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """CLI entry point: build the persona Wasserstein-distance heat-map.

    Reads the predictions JSON (``--json_path``), computes the pairwise
    distance matrix, relabels it with human-readable persona names, and saves
    the heat-map to ``--out_dir``. When ``--out_dir`` is not given, the output
    goes to ``outputs/`` next to this file.

    Raises:
        SystemExit: if the predictions JSON file does not exist.
    """
    args = parse_args()

    base_dir = os.path.dirname(os.path.abspath(__file__))
    out_dir = args.out_dir or os.path.join(base_dir, "outputs")

    # The fallback below is a placeholder path; real runs should pass --json_path.
    json_path = args.json_path or "/path/to/gpt_4o_webaes_static_final.json"
    if not os.path.isfile(json_path):
        raise SystemExit(f"Predictions JSON not found: {json_path} (pass --json_path)")

    predictions = load_predictions(json_path)
    dist_df = compute_wasserstein_matrix(predictions)

    # Swap the underscore persona keys for display labels on the plot axes
    dist_df.index = _DISPLAY_NAMES
    dist_df.columns = _DISPLAY_NAMES

    # ----------------------------------------------------
    # Plot & save
    # ----------------------------------------------------
    plot_distance(dist_df, out_dir)

    print(f"✅ Wasserstein distance heat-map saved to: {out_dir}")
    print(f"Using predictions from: {json_path}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()